"""A general gymnasium environment for dynamic foraging tasks in AIND.

Adapted from Han's code for the project in Neuromatch Academy: Deep Learning
https://github.com/hanhou/meta_rl/blob/bd9b5b1d6eb93d217563ff37608aaa2f572c08e6/han/environment/dynamic_bandit_env.py

See also Po-Chen Kuo's implementation:
https://github.com/pckuo/meta_rl/blob/main/environments/bandit/bandit.py
"""

import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding
from typing import List

import logging


# Action indices shared by all tasks in this module: left arm, right arm, and the
# "ignore this trial" action of a two-arm task (it follows the last arm index).
L, R, IGNORE = 0, 1, 2


class DynamicForagingTaskBase(gym.Env):
    """
    A general gymnasium environment for dynamic bandit tasks.

    Subclasses implement :meth:`generate_new_trial`, which must fill
    ``self.trial_p_reward[self.trial]`` with the reward probability of each arm for the
    current trial.

    Actions ``0 .. num_arms - 1`` choose an arm. When ``allow_ignore`` is True, one extra
    action (index ``num_arms``) means "ignore this trial".

    Adapted from https://github.com/thinkjrs/gym-bandit-environments/blob/master/gym_bandits/bandit.py  # noqa E501
    """

    def __init__(
        self,
        reward_baiting: bool = False,  # Whether the reward is baited
        allow_ignore: bool = False,  # Allow the agent to ignore the task
        num_arms: int = 2,  # Number of arms in the bandit
        num_trials: int = 1000,  # Number of trials in the session
        seed=None,
    ):
        """Init

        Args:
            reward_baiting: if True, a reward assigned to an arm but not collected stays
                available on later trials until it is collected.
            allow_ignore: if True, add the extra "ignore" action (index ``num_arms``).
            num_arms: number of arms in the bandit.
            num_trials: number of trials in one session (episode).
            seed: seed for ``self.rng``, the generator used for all task randomness.
        """
        self.num_trials = num_trials
        self.max_episode_steps = num_trials  # for compatibility with gym
        self.reward_baiting = reward_baiting
        self.num_arms = num_arms
        self.allow_ignore = allow_ignore

        # State space
        # - Time (trial number) is the only observable state to the agent
        self.observation_space = spaces.Dict(
            {
                "trial": spaces.Box(low=0, high=self.num_trials, dtype=np.int64),
            }
        )

        # Action space: one action per arm, plus a final "ignore" action if allowed
        num_actions = num_arms + int(allow_ignore)
        self.action_space = spaces.Discrete(num_actions)

        # Random number generator used by the task (reward draws, schedules, ...)
        self.rng = np.random.default_rng(seed)

    def reset(self, seed=None, options=None):
        """Start a new session (episode).

        Subclasses in this module set up their own schedule state first and then call
        this method, which allocates the per-trial arrays and generates the first trial.

        Args:
            seed: if not None, re-seeds both gymnasium's ``self.np_random`` and the task
                generator ``self.rng``, making the session reproducible. Typically passed
                once, right after the environment has been created.
            options: unused; accepted for gymnasium API compatibility.

        Returns:
            ``(observation, info)`` for the first trial.
        """
        # Seeds gymnasium's self.np_random
        super().reset(seed=seed)
        # All task randomness comes from self.rng (not self.np_random), so a seed passed
        # to reset() must be applied here too, otherwise it would have no effect.
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        # Some mandatory initialization for any dynamic foraging task
        self.trial = np.array([0])  # 1-element array; used directly as an index below
        self.trial_p_reward = np.empty((self.num_trials, self.num_arms))
        # Whether a reward is available on each arm before / after the agent's action
        self.reward_assigned_before_action = np.zeros_like(self.trial_p_reward)
        self.reward_assigned_after_action = np.zeros_like(self.trial_p_reward)
        # Cache of the uniform random numbers drawn for reward assignment
        self.random_numbers = np.empty_like(self.trial_p_reward)

        self.action = np.empty(self.num_trials, dtype=int)
        self.reward = np.empty(self.num_trials)

        self.generate_new_trial()  # Generate p_reward for the first trial

        return self._get_obs(), self._get_info()

    def step(self, action):
        """Execute one trial.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``terminated`` is True
            on the last trial, after which ``reset()`` must be called before stepping
            again. ``truncated`` is always False.
        """
        # Valid actions: [0, num_arms - 1], plus num_arms when allow_ignore is True
        assert self.action_space.contains(action)
        self.action[self.trial] = action

        # Generate reward
        reward = self.generate_reward(action)
        self.reward[self.trial] = reward

        # Decide termination before advancing; self.trial starts from 0
        terminated = bool(self.trial == self.num_trials - 1)

        # Advance to the next trial only if the session continues
        if not terminated:
            self.trial += 1  # tick time here
            self.generate_new_trial()

        return self._get_obs(), reward, terminated, False, self._get_info()

    def generate_reward(self, action):
        """Assign rewards for the current trial and return the reward earned by ``action``.

        Could be overridden by subclasses for more complex reward structures.
        """
        # -- Refill rewards on this trial --
        self.random_numbers[self.trial] = self.rng.uniform(0, 1, size=self.num_arms)
        reward_assigned = (
            self.random_numbers[self.trial] < self.trial_p_reward[self.trial]
        ).astype(float)

        # -- Uncollected rewards carry over from the last trial (baiting) --
        if self.reward_baiting and self.trial > 0:
            reward_assigned = np.logical_or(
                reward_assigned, self.reward_assigned_after_action[self.trial - 1]
            ).astype(float)

        # Flatten to shape (num_arms,)
        reward_assigned = reward_assigned.reshape(-1)

        # Cache the reward assignment
        self.reward_assigned_before_action[self.trial] = reward_assigned
        self.reward_assigned_after_action[self.trial] = reward_assigned

        # -- Reward delivery --
        # The ignore action is the one right after the last arm. Comparing with the
        # module constant IGNORE (= 2) would treat arm 2 as "ignore" when num_arms > 2.
        if self.allow_ignore and action == self.num_arms:
            # Note that reward may be still refilled even if the agent ignores the trial
            return 0

        # The chosen arm's reward is collected: clear it from the carry-over slot
        self.reward_assigned_after_action[self.trial, action] = 0

        return reward_assigned[action]

    def generate_new_trial(self):
        """Fill ``self.trial_p_reward[self.trial]`` for the current trial.

        Must be overridden by subclasses. Note that ``self.trial`` has already been
        advanced to the new trial when this is called.
        """
        raise NotImplementedError("generate_new_trial() should be overridden by subclasses")

    def get_choice_history(self):
        """Return the history of actions in format that is compatible with other library such as
        aind_dynamic_foraging_basic_analysis

        Ignored trials are reported as NaN.
        """
        actions = self.action.astype(float)
        # The ignore action has index num_arms; it is never a valid arm index
        actions[actions == self.num_arms] = np.nan
        return actions

    def get_reward_history(self):
        """Return the history of rewards in format that is compatible with other library such as
        aind_dynamic_foraging_basic_analysis
        """
        return self.reward

    def get_p_reward(self):
        """Return the reward probabilities for each arm in each trial which is compatible with
        other library such as aind_dynamic_foraging_basic_analysis

        Shape is ``(num_arms, num_trials)``.
        """
        return self.trial_p_reward.T

    def _get_obs(self):
        """Return the observation.

        The trial counter is copied: ``self.trial`` is incremented in place, so returning
        it directly would make earlier observations change after later steps.
        """
        return {"trial": self.trial.copy()}

    def _get_info(self):
        """
        Info about the environment that the agent is not supposed to know.
        For instance, info can reveal the index of the optimal arm,
        or the value of prior parameter.
        Can be useful to evaluate the agent's performance
        """
        return {
            "trial": self.trial.copy(),  # snapshot, see _get_obs
            "task_object": self,  # Return the whole task object for debugging
        }

    def _seed(self, seed=None):
        """Legacy gym-style seeding of ``self.np_random`` (does not touch ``self.rng``)."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]


#############################################################################
"""
Coupled block task for the dynamic bandit environment
This is very close to the task used in mice training.
"""
class CoupledBlockTask(DynamicForagingTaskBase):
    """Coupled block task for dynamic foraging

    The session is divided into blocks. Within a block, the arms keep fixed reward
    probabilities drawn from ``p_reward_pairs``. The richer side alternates across
    consecutive unequal blocks; equal-probability blocks are skipped over when deciding
    the flip.

    This default setting roughly matches what has been used in this paper:
    https://www.sciencedirect.com/science/article/pii/S089662731930529X
    """

    def __init__(
        self,
        block_min: int = 40,  # Min block length
        block_max: int = 80,  # Max block length
        block_beta: int = 20,  # Time constant of exponential distribution (the larger the flatter)
        p_reward_pairs: List[List[float]] = None,  # List of reward probability pairs
        **kwargs,
    ):
        """Init

        Args:
            block_min: minimum block length.
            block_max: maximum block length (longer draws are clipped to it).
            block_beta: scale of the exponential added to ``block_min``.
            p_reward_pairs: candidate ``[p_a, p_b]`` pairs; each is sorted ascending, and
                the side assignment is decided per block.
            **kwargs: forwarded to DynamicForagingTaskBase.
        """
        super().__init__(**kwargs)

        if p_reward_pairs is None:
            p_reward_pairs = [
                [0.225, 0.225],  # 1:1
                [0.45 / 4 * 1, 0.45 / 4 * 3],  # 1:3
                [0.45 / 7 * 1, 0.45 / 7 * 6],  # 1:6
                [0.05, 0.40],  # 1:8
            ]

        self.block_min = block_min
        self.block_max = block_max
        self.block_beta = block_beta
        self.p_reward_pairs = [sorted(ps) for ps in p_reward_pairs]  # Always sort the input ps

    def reset(self, seed=None, options=None):
        """Reset the block schedule, then let the base class generate the first trial."""

        # Add more initialization specific to this task
        self.block_starts = [0]  # Start of each block. The first block always starts at trial 0
        self.block_lens = []  # Lengths of each block
        self.block_p_reward = []  # Rwd prob of each block

        # Call the base class reset at the end
        return super().reset(seed=seed, options=options)

    def generate_new_trial(self):
        """Set the current trial's reward probabilities, starting a new block when due.

        Note that ``self.trial`` has already been advanced to the new trial here.

        Returns:
            The reward probabilities of the current trial, shape ``(num_arms,)``.
        """
        # block_starts[-1] is the scheduled start of the next block
        if self.trial == self.block_starts[-1]:
            self._next_block()

        # Every trial in a block shares the block's reward probabilities
        self.trial_p_reward[self.trial, :] = self.block_p_reward[-1]
        # Return the row just written for this trial
        return self.trial_p_reward[self.trial, :].reshape(-1)

    def _next_block(self):
        """
        Generate the next block: draw its length and reward probabilities.
        """
        # Generate the block length
        self.block_lens.append(
            int(
                generate_trunc_exp(
                    self.block_min,
                    self.block_max,
                    self.block_beta,
                    rng=self.rng,
                )[0]
            )
        )
        self.block_starts.append(self.block_starts[-1] + self.block_lens[-1])

        # Generate the reward probability
        self.block_p_reward.append(self._generate_block_p_reward())
        return

    def _generate_block_p_reward(self):
        """
        Generate the reward probability for the next block.
        """
        # If it is the first block, randomly choose a pair and the side
        if len(self.block_p_reward) == 0:
            p_reward = self.rng.choice(self.p_reward_pairs)
            p_reward = self._flip_side(p_reward, None)
            return p_reward

        # Else, generate a new p_reward based on the current p_reward
        # 1. if current p_L == p_R, randomly choose a p_reward_pair (excluding p_L == p_R)
        #    and make sure the new block is flipped compared
        #    to the one before the equal-probability block
        # 2. else, randomly choose a p_reward_pair and always flip the side
        if self.block_p_reward[-1][L] == self.block_p_reward[-1][R]:
            # Cannot be p_L == p_R again. If every configured pair is equal, fall back to
            # all pairs instead of letting rng.choice fail on an empty list.
            valid_pairs = [p for p in self.p_reward_pairs if p[L] != p[R]] or self.p_reward_pairs
            # Randomly choose from the valid pairs
            p_reward = self.rng.choice(valid_pairs)
            # If there is a block before the equal-probability block, flip relative to it
            # otherwise, randomly choose
            p_reward = self._flip_side(
                p_reward, self.block_p_reward[-2] if len(self.block_p_reward) > 1 else None
            )
        else:
            # Randomly choose from any pairs
            p_reward = self.rng.choice(self.p_reward_pairs)
            # Make sure the side is flipped
            p_reward = self._flip_side(p_reward, self.block_p_reward[-1])

        return p_reward

    def _flip_side(self, p_reward_new, p_reward_old=None):
        """
        Orient ``p_reward_new`` so its richer side differs from that of ``p_reward_old``.

        ``p_reward_new`` is reversed when its L<R ordering matches the old one's.
        If ``p_reward_old`` is None, reverse it with probability 0.5.
        """
        should_flip = p_reward_old is None and self.rng.random() < 0.5
        if p_reward_old is not None:
            should_flip = (p_reward_new[L] < p_reward_new[R]) == (p_reward_old[L] < p_reward_old[R])

        return p_reward_new[::-1] if should_flip else p_reward_new


def generate_trunc_exp(lower, upper, beta, n=1, rng=None):
    """
    Draw ``n`` samples of ``lower + Exponential(scale=beta)``, capped at ``upper``.

    Samples that would exceed ``upper`` are set to ``upper`` (clipped, so probability mass
    accumulates at ``upper``) rather than re-drawn.

    Args:
        lower: offset added to every sample (the minimum value).
        upper: cap applied to every sample.
        beta: scale of the exponential distribution.
        n: number of samples.
        rng: ``numpy.random.Generator`` to draw from; a fresh one is created if None.

    Returns:
        Float array of shape ``(n,)``.
    """
    generator = np.random.default_rng() if rng is None else rng
    samples = lower + generator.exponential(beta, n)
    return np.minimum(samples, upper)


#############################################################################
"""
Random walk task for the dynamic bandit environment.
"""
class RandomWalkTask(DynamicForagingTaskBase):
    """
    Generate reward schedule with random walk

    (see Miller et al. 2021, https://www.biorxiv.org/content/10.1101/461129v3.full.pdf)

    Each side's reward probability starts uniformly in ``[p_min, p_max]``. On every later
    trial it is redrawn from ``Normal(previous + mean, sigma)`` and clipped to
    ``[p_min, p_max]``. While ``hold_this_block`` is True, the previous values are repeated.
    """

    def __init__(
        self,
        p_min=(0, 0),  # The lower bound of p_L and p_R
        p_max=(1, 1),  # The upper bound of p_L and p_R
        sigma=(0.15, 0.15),  # The std of each step of the random walk
        mean=(0, 0),  # The mean of each step of the random walk
        **kwargs,
    ) -> None:
        """Init

        Each of ``p_min``, ``p_max``, ``sigma`` and ``mean`` may be a scalar (used for both
        sides, kept for backward compatibility) or a ``[left, right]`` sequence.
        Remaining keyword arguments are forwarded to DynamicForagingTaskBase.
        """
        super().__init__(**kwargs)

        # Normalize every parameter to a fresh [left, right] list (never shared defaults)
        self.p_min = self._as_pair(p_min)
        self.p_max = self._as_pair(p_max)
        self.sigma = self._as_pair(sigma)
        self.mean = self._as_pair(mean)

    @staticmethod
    def _as_pair(value):
        """Return ``[value, value]`` for a scalar, otherwise ``value`` copied into a list."""
        if np.ndim(value) == 0:
            return [value, value]
        return list(value)

    def reset(self, seed=None, options=None):
        """Reset the task, remember to call the base class reset at the end."""
        self.hold_this_block = False

        return super().reset(seed=seed, options=options)

    def generate_new_trial(self):
        """Generate a new trial. Overwrite the base class method."""
        # Note that self.trial already increased by 1 here
        self.trial_p_reward[self.trial, :] = np.array(
            [self._generate_next_p(side) for side in [L, R]]
        ).reshape(-1)

    def _generate_next_p(self, side):
        """Generate p_side for the current trial."""
        if self.trial == 0:
            # Start with uniform distribution
            return self.rng.uniform(self.p_min[side], self.p_max[side])
        if self.hold_this_block:
            # Held: repeat the previous trial's probability
            return self.trial_p_reward[self.trial - 1, side]

        # Take a random walk step: N(prev + mean, sigma), i.e. prev + N(mean, sigma)
        p = self.rng.normal(
            self.trial_p_reward[self.trial - 1, side] + self.mean[side],
            self.sigma[side],
        )
        # self.trial is a 1-element array, so the draw has shape (1,); unpack it
        p = p[0]
        # Absorb at the boundary
        return min(self.p_max[side], max(self.p_min[side], p))

    def plot_reward_schedule(self):
        """Plot the reward schedule and compute the auto-correlation.

        Top-left: p_L (red) and p_R (blue) per trial. Top-right: auto-correlation of each
        side. Bottom-left: p_L + p_R and p_R / (p_L + p_R).

        Returns:
            The matplotlib figure.
        """
        trial_p_reward = np.array(self.trial_p_reward)

        fig, ax = plt.subplots(
            2, 2, figsize=[15, 7], sharex="col", gridspec_kw=dict(width_ratios=[4, 1], wspace=0.1)
        )

        for s, col in zip([L, R], ["r", "b"]):
            ax[0, 0].plot(trial_p_reward[:, s], col, marker=".", alpha=0.5, lw=2)
            ax[0, 1].plot(auto_corr(trial_p_reward[:, s]), col)

        p_sum = trial_p_reward[:, L] + trial_p_reward[:, R]
        ax[1, 0].plot(p_sum, label="sum")
        # p_L = p_R = 0 is reachable (p_min may be 0); the ratio is then NaN, which is
        # drawn as a gap, so silence the 0/0 warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = trial_p_reward[:, R] / p_sum
        ax[1, 0].plot(ratio, label="R/(L+R)")
        ax[1, 0].legend()

        ax[0, 1].set(title="auto correlation", xlim=[0, 100])
        ax[0, 1].axhline(y=0, c="k", ls="--")

        return fig


def auto_corr(data):
    """Return the normalized auto-correlation of ``data`` for lags 0 .. len(data) - 1.

    The series is mean-centered and the raw correlation is divided by ``var(data)`` and
    then by ``len(data)``, so lag 0 equals 1 for non-constant data.
    """
    variance = np.var(data)
    centered = data - np.mean(data)
    n_samples = len(centered)
    # "full" mode yields lags -(n-1) .. n-1; keep the non-negative half
    correlation = np.correlate(centered, centered, mode="full")[n_samples - 1 :]  # noqa E203
    return correlation / variance / n_samples


#############################################################################
"""
Uncoupled block task for the dynamic bandit environment
see /test/test_uncoupled_block_task.py for usage
"""

logger = logging.getLogger(name=__name__)  # Module-level logger for schedule messages


class UncoupledBlockTask(DynamicForagingTaskBase):
    """
    Generate uncoupled block reward schedule
    (by on-line updating)

    adapted from Cohen lab's Arduino code (with some bug fixes)
    https://github.com/JeremiahYCohenLab/arduinoLibraries/blob/master/libraries/task_operantMatchingDecoupledBait/task_operantMatchingDecoupledBait.cpp  # noqa E501

    See Grossman et al. 2022:

    In the final stage of the task, the reward probabilities assigned to each lick spout were drawn
    pseudorandomly from the set {0.1, 0.5, 0.9} in all the mice from the behavior experiments (n=46),
    all the mice from the DREADDs experiments (n=10), and half of the mice from the electrophysiology
    experiments (n=2). The other half of mice from the electrophysiology experiments (n=2) were run
    on a version of the task with probabilities drawn from the set {0.1, 0.4, 0.7}. The probabilities
    were assigned to each spout individually with block lengths drawn from a uniform distribution
    of 20–35 trials. To stagger the blocks of probability assignment for each spout, the block length
    for one spout in the first block of each session was drawn from a uniform distribution of 6–21
    trials. For each spout, probability assignments could not be repeated across consecutive blocks.
    To maintain task engagement, reward probabilities of 0.1 could not be simultaneously assigned
    to both spouts. If one spout was assigned a reward probability greater than or equal to the reward
    probability of the other spout for 3 consecutive blocks, the probability of that spout was set to
    0.1 to encourage switching behavior and limit the creation of a direction bias. If a mouse
    perseverated on a spout with a reward probability of 0.1 for 4 consecutive trials, 4 trials were
    added to the length of both blocks. This procedure was implemented to keep mice from choosing
    one spout until the reward probability became high again.

    """

    def __init__(
        self,
        rwd_prob_array=(0.1, 0.5, 0.9),
        block_min=20,
        block_max=35,
        persev_add=True,
        perseverative_limit=4,
        max_block_tally=4,  # Max number of consecutive blocks in which one side is better
        **kwargs,
    ) -> None:
        """Init

        Args:
            rwd_prob_array: candidate reward probabilities for each side's blocks.
            block_min: minimum block length (inclusive).
            block_max: maximum block length (inclusive).
            persev_add: whether to enable the anti-perseveration mechanism.
            perseverative_limit: number of choices of a side while it has the lowest
                probability (without an intervening choice of the other side) that triggers
                the mechanism; also the number of trials then added to both current blocks.
            max_block_tally: when a side's tally of consecutive block-switch comparisons in
                which it was >= the other side reaches this value, its next block is forced
                to the lowest probability.
            **kwargs: forwarded to DynamicForagingTaskBase.
        """
        super().__init__(**kwargs)

        # Copy into a list so instances never share (or alias) a mutable default
        self.rwd_prob_array = list(rwd_prob_array)
        self.block_min = block_min
        self.block_max = block_max
        self.persev_add = persev_add
        self.perseverative_limit = perseverative_limit
        self.max_block_tally = max_block_tally

        # Trials subtracted from a block end to desynchronize the two sides' block switches
        self.block_stagger = int((round(block_max - block_min - 0.5) / 2 + block_min) / 2)

    def reset(self, seed=None, options=None):
        """Reset the schedule state, then let the base class generate the first trial."""
        self.rwd_tally = [0, 0]  # List for 'L' and 'R'

        self.block_ends = [[], []]  # List for 'L' and 'R', Trial number on which each block ends
        self.block_rwd_prob = [[], []]  # List for 'L' and 'R', Reward probability
        # List for 'L' and 'R', Index of current block (= len(block_end_at_trial))
        self.block_ind = [0, 0]

        self.force_by_tally = [[], []]  # List for 'L' and 'R'
        self.force_by_both_lowest = [[], []]  # List for 'L' and 'R'

        # Anti-persev
        self.persev_consec_on_min_prob = [0, 0]  # List for 'L' and 'R'
        self.persev_add_at_trials = []

        # Manually block hold
        self.hold_this_block = False

        return super().reset(seed=seed, options=options)

    def _current_trial(self):
        """Return the current trial index as a plain int.

        ``self.trial`` is a 1-element array that the base class increments in place, so
        storing it directly in a history list would make the entry follow later trials.
        """
        return int(np.asarray(self.trial).reshape(-1)[0])

    def generate_new_trial(self):
        """Generate a new trial. Overwrite the base class method.

        Returns:
            ``(switched, msg)``: per side, whether its probability differs from the previous
            trial (``[0, 0]`` on the first trial), and a log message describing any forced
            block changes.
        """
        msg = ""

        if self.trial == 0:
            self.generate_first_block()

        # Block switch?
        if not self.hold_this_block:
            for s in [L, R]:
                if self.trial >= self.block_ends[s][self.block_ind[s]]:
                    # In case a block is manually 'held', record the actual block transition
                    self.block_ends[s][self.block_ind[s]] = self._current_trial()

                    self.block_ind[s] += 1
                    self.block_effective_ind += 1
                    # Accumulate: both sides may switch on the same trial
                    msg += (
                        self.generate_next_block(
                            s, check_higher_in_a_row=True, check_both_lowest=True
                        )
                        + "\n"
                    )

        # Fill new value
        self.trial_p_reward[self.trial, :] = [
            self.block_rwd_prob[L][self.block_ind[L]],
            self.block_rwd_prob[R][self.block_ind[R]],
        ]

        # Anti-persev
        if not self.hold_this_block and self.persev_add and self.trial > 0:
            msg = msg + self.auto_shape_perseverance()
        else:
            for s in [L, R]:
                self.persev_consec_on_min_prob[s] = 0

        assert self.block_ind[L] + 1 == len(self.block_rwd_prob[L]) == len(self.block_ends[L])
        assert self.block_ind[R] + 1 == len(self.block_rwd_prob[R]) == len(self.block_ends[R])

        return (
            [
                self.trial_p_reward[self.trial - 1, s] != self.trial_p_reward[self.trial, s]
                for s in [L, R]
            ]  # Whether block just switched
            if self.trial > 0
            else [0, 0]
        ), msg

    def generate_first_block(self):
        """Generate the first block of both sides. Note the stagger is applied."""
        for side in [L, R]:
            self.generate_next_block(side)

        # Avoid both blocks have the lowest reward prob
        while np.all([x[0] == np.min(self.rwd_prob_array) for x in self.block_rwd_prob]):
            # Randomly change one side to another prob
            self.block_rwd_prob[self.rng.choice([L, R])][0] = self.rng.choice(
                self.rwd_prob_array
            )

        # Start with block stagger: the lower side (L on ties) makes the first switch earlier
        smaller_side = np.argmin([self.block_rwd_prob[L][0], self.block_rwd_prob[R][0]])
        self.block_ends[smaller_side][0] -= self.block_stagger

        self.block_effective_ind = 1  # Effective block ind

    def generate_next_block(self, side, check_higher_in_a_row=True, check_both_lowest=True):
        """Append the next block (end trial and reward probability) for ``side``.

        Args:
            side: L or R.
            check_higher_in_a_row: apply the "higher for too many blocks" rule, which forces
                this side to the lowest probability.
            check_both_lowest: if both sides end up at the lowest probability, stagger this
                side and immediately switch the other side to a new block.

        Returns:
            A log message describing any forced change ("" if none).
        """
        msg = ""
        other_side = R if side == L else L
        random_block_len = self.rng.integers(low=self.block_min, high=self.block_max + 1)

        if self.block_ind[side] == 0:  # The first block
            self.block_ends[side].append(random_block_len)
            self.block_rwd_prob[side].append(self.rng.choice(self.rwd_prob_array))

        else:  # Not the first block
            self.block_ends[side].append(
                random_block_len + self.block_ends[side][self.block_ind[side] - 1]
            )

            # If this side has higher prob for too long, force it to be the lowest
            if check_higher_in_a_row:
                # For each effective block, update number of times each side >= the other side
                this_prev = self.block_rwd_prob[side][self.block_ind[side] - 1]
                other_now = self.block_rwd_prob[other_side][self.block_ind[other_side]]
                if this_prev > other_now:
                    self.rwd_tally[side] += 1
                    self.rwd_tally[other_side] = 0
                elif this_prev == other_now:
                    self.rwd_tally[side] += 1
                    self.rwd_tally[other_side] += 1
                else:
                    self.rwd_tally[other_side] += 1
                    self.rwd_tally[side] = 0

                # Only check higher-in-a-row for this side
                if self.rwd_tally[side] >= self.max_block_tally:
                    msg = (
                        f"--- {self.trial}: {side} is higher for {self.rwd_tally[side]} "
                        f"eff_blocks, force {side} to lowest ---\n"
                    )
                    logger.info(msg)
                    self.block_rwd_prob[side].append(min(self.rwd_prob_array))
                    self.rwd_tally[side] = self.rwd_tally[other_side] = 0
                    self.force_by_tally[side].append(self._current_trial())
                else:  # Otherwise, randomly choose one
                    self.block_rwd_prob[side].append(self.rng.choice(self.rwd_prob_array))
            else:
                self.block_rwd_prob[side].append(self.rng.choice(self.rwd_prob_array))

            # Don't repeat the previous rwd prob
            # (this will not mess up with the "forced" case since the previous block cannot be
            # the lowest prob in the first place)
            while self.block_rwd_prob[side][-2] == self.block_rwd_prob[side][-1]:
                self.block_rwd_prob[side][-1] = self.rng.choice(self.rwd_prob_array)

            # If the other side is already at the lowest prob AND this side just generated the
            # same (either through the "forced" case or not), push the other side to a higher prob
            if check_both_lowest and (
                self.block_rwd_prob[side][-1]
                == self.block_rwd_prob[other_side][-1]
                == min(self.rwd_prob_array)
            ):
                # Stagger this side
                self.block_ends[side][-1] -= self.block_stagger

                # Force block switch of the other side
                msg += f"--- {self.trial}: both side is the lowest, push {side} to higher ---"
                logger.info(msg)
                self.force_by_both_lowest[side].append(self._current_trial())
                self.block_ends[other_side][-1] = self._current_trial()
                # Two sides change at the same time, no need to add block_effective_ind twice
                self.block_ind[other_side] += 1
                # Just generate new block, no need to do checks
                self.generate_next_block(
                    other_side, check_higher_in_a_row=False, check_both_lowest=False
                )
        return msg

    def auto_shape_perseverance(self):
        """Anti-perseverance mechanism

        See Grossman et al. 2022:

        If a mouse perseverated on a spout with reward probability of 0.1 for 4 consecutive trials,
        4 trials were added to the length of both blocks. This procedure was implemented to keep
        mice from choosing one spout until the reward probability became high again.

        Returns:
            A log message ("" if the mechanism was not triggered on this trial).
        """
        msg = ""
        for s in [L, R]:
            if self.action[self.trial - 1] == s:  # Note that self.trial already increased 1
                # Reset other side as soon as there is an opposite choice
                self.persev_consec_on_min_prob[1 - s] = 0
                # If last choice is on side with min_prob (0.1), add counter
                if self.trial_p_reward[self.trial - 1, s] == min(self.rwd_prob_array):
                    self.persev_consec_on_min_prob[s] += 1

        for s in [L, R]:
            if self.persev_consec_on_min_prob[s] >= self.perseverative_limit:
                for ss in [L, R]:
                    # Add 'perseverative_limit' trials to both blocks
                    self.block_ends[ss][-1] += self.perseverative_limit
                    self.persev_consec_on_min_prob[ss] = 0
                msg = f"persev at side = {s}, added {self.perseverative_limit} trials to both sides"
                logger.info(msg)
                self.persev_add_at_trials.append(self._current_trial())
        return msg

    def plot_reward_schedule(self):
        """Plot the reward schedule with annotations showing forced block switches

        Returns:
            ``(fig, ax)`` from matplotlib.
        """
        fig, ax = plt.subplots(2, 1, figsize=[15, 7], sharex="col")

        def annotate_block(axis):
            """Annotate block switches, forced switches, choices and anti-persev additions."""
            for s, col in zip([L, R], ["r", "b"]):
                # Block ends (R shifted by 0.1 so coinciding switches stay visible)
                for x in self.block_ends[s]:
                    axis.axvline(x + (0.1 if s == R else 0), 0, 1, color=col, ls="--", lw=0.5)

                axis.plot(
                    self.force_by_tally[s],
                    [1.2] * len(self.force_by_tally[s]),
                    marker=">",
                    ls="",
                    color=col,
                    label="forced by tally",
                )
                axis.plot(
                    self.force_by_both_lowest[s],
                    [1.1] * len(self.force_by_both_lowest[s]),
                    marker="v",
                    ls="",
                    color=col,
                    label="forced by both lowest",
                )

            for s, col, pos, m in zip(
                [L, R, IGNORE], ["r", "b", "k"], [0, 1, 0.95], ["|", "|", "x"]
            ):
                # np.where returns a tuple of index arrays; take the trial indices
                this_choice = np.where(self.action == s)[0]
                axis.plot(this_choice, [pos] * len(this_choice), m, color=col)

            axis.plot(
                self.persev_add_at_trials,
                [1.05] * len(self.persev_add_at_trials),
                marker="+",
                label=f"anti-persev added {self.perseverative_limit} trials",
                ls="",
                color="c",
            )

        for s, col in zip([L, R], ["r", "b"]):
            ax[0].plot(self.trial_p_reward[:, s], col, marker=".", alpha=0.5, lw=2)
        annotate_block(ax[0])

        ax[1].plot(
            self.trial_p_reward.sum(axis=1),
            label="sum",
        )
        ax[1].plot(
            self.trial_p_reward[:, R] / self.trial_p_reward.sum(axis=1),
            label="R/(L+R)",
        )
        ax[1].legend()
        annotate_block(ax[1])

        return fig, ax